{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 29. How PyTorch RNN Works and a From-Scratch Reimplementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
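    "Memory cell: a tensor that stores information from earlier time steps.\n",
    "Sequential features need this history to inform the decision at each step.\n",
    "\n",
    "![Unidirectional RNN](http://assets.hypervoid.top/img/2025/06/30/image-20250630145410308-ace4.png)\n",
    "![Bidirectional RNN](http://assets.hypervoid.top/img/2025/06/30/image-20250630145447236-42c9.png)"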
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How different sequence-modeling approaches perform:\n",
    "\n",
    "\n",
    "![image-20250630145536378](http://assets.hypervoid.top/img/2025/06/30/image-20250630145536378-1b5d.png)\n",
    "\n",
    "`delay` is the number of input frames, n, that the memory cell sees ahead when it predicts an output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
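    "Pros and cons of RNNs:\n",
    "\n",
    "- Pros:\n",
    "    1. Handles variable-length sequences\n",
    "    2. The number of parameters is independent of sequence length\n",
    "    3. Computation grows linearly with sequence length (quadratically for Transformers)\n",
    "    4. Takes past information into account\n",
    "    5. Well suited to streaming output\n",
    "    6. Weights are time-invariant: the same weights are reused at every time step\n",
    "- Cons:\n",
    "    1. Sequential computation is slow, because time steps cannot run in parallel\n",
    "    2. Cannot retain information from the distant past"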
   ]
  },
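  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Advantage 3 can be made concrete with a rough operation count. This is only a sketch under assumed notation that the notes above do not define: $T$ is the sequence length, $I$ the input feature size, $H$ the hidden size, and $d$ the model width of a Transformer layer.\n",
    "\n",
    "$$\n",
    "\\text{RNN: } O\\big(T\\,(HI + H^2)\\big) \\qquad \\text{self-attention scores: } O\\big(T^2 d\\big)\n",
    "$$\n",
    "\n",
    "Each RNN step multiplies the input by an $H \\times I$ matrix and the previous hidden state by an $H \\times H$ matrix, and there are $T$ such steps. Self-attention compares every position with every other position, which is where the $T^2$ factor comes from."
   ]
  },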
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "h_t = \\tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})\n",
    "$$\n"
   ]
  },
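  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Writing out the shapes is a quick way to check that the recurrence above is dimensionally consistent. The symbols $B$ (batch size), $I$ (input size), and $H$ (hidden size) are notation assumed here for illustration.\n",
    "\n",
    "$$\n",
    "x_t \\in \\mathbb{R}^{B \\times I},\\quad W_{ih} \\in \\mathbb{R}^{H \\times I},\\quad h_{t-1} \\in \\mathbb{R}^{B \\times H},\\quad W_{hh} \\in \\mathbb{R}^{H \\times H},\\quad b_{ih}, b_{hh} \\in \\mathbb{R}^{H}\n",
    "$$\n",
    "\n",
    "$$\n",
    "x_t W_{ih}^T \\in \\mathbb{R}^{B \\times H},\\quad h_{t-1} W_{hh}^T \\in \\mathbb{R}^{B \\times H} \\;\\Rightarrow\\; h_t \\in \\mathbb{R}^{B \\times H}\n",
    "$$\n",
    "\n",
    "The two bias vectors broadcast over the batch dimension. The parameter count is $HI + H^2 + 2H$, which does not depend on the sequence length. With $I = 2$ and $H = 3$, the sizes used in the hand-written implementation, this gives $6 + 9 + 6 = 21$ parameters."
   ]
  },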
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch RNN API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 2, 3]), torch.Size([1, 1, 3]))"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
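    "# nn.RNN(input_size=4, hidden_size=3, num_layers=1)\n",
    "single_rnn = nn.RNN(4, 3, 1, batch_first=True)\n",
    "input = torch.randn(1, 2, 4)  # (batch, seq_len, input_size)\n",
    "output, h_n = single_rnn(input)\n",
    "output.shape, h_n.shape  # (batch, seq_len, hidden_size), (num_layers, batch, hidden_size)"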
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 2, 6]), torch.Size([2, 1, 3]))"
      ]
     },
     "execution_count": 133,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
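    "torch_rnn = nn.RNN(4, 3, 1, batch_first=True, bidirectional=True)\n",
    "input = torch.randn(1, 2, 4)\n",
    "output, h_n = torch_rnn(input)\n",
    "# The feature dimension of output doubles to 2 * hidden_size (both directions are concatenated),\n",
    "# and h_n holds one final state per direction.\n",
    "output.shape, h_n.shape"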
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing RNN by Hand"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
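    "# Seed first so that both the random input and the nn.RNN weights created below are reproducible\n",
    "torch.manual_seed(20250630)\n",
    "torch.cuda.manual_seed(20250630)\n",
    "\n",
    "batch_size, seq_len = 2, 3\n",
    "# Feature sizes\n",
    "input_size, hidden_size = 2, 3\n",
    "\n",
    "input = torch.randn(batch_size, seq_len, input_size)\n",
    "# Initial hidden state: (num_layers * num_directions, batch, hidden_size)\n",
    "h_prev = torch.zeros((1, batch_size, hidden_size))"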
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 3]) torch.Size([1, 2, 3])\n",
      "---Parameters---\n",
      "weight_ih_l0 torch.Size([3, 2])\n",
      "weight_hh_l0 torch.Size([3, 3])\n",
      "bias_ih_l0 torch.Size([3])\n",
      "bias_hh_l0 torch.Size([3])\n",
      "---Output---\n",
      "tensor([[[ 0.7332, -0.1584, -0.4056],\n",
      "         [ 0.8495,  0.7383, -0.1428],\n",
      "         [ 0.9238,  0.2170,  0.5095]],\n",
      "\n",
      "        [[ 0.7086, -0.0690, -0.6058],\n",
      "         [ 0.8293,  0.9016, -0.0904],\n",
      "         [ 0.8779,  0.7847, -0.0404]]], grad_fn=<TransposeBackward1>)\n",
      "tensor([[[ 0.9238,  0.2170,  0.5095],\n",
      "         [ 0.8779,  0.7847, -0.0404]]], grad_fn=<StackBackward0>)\n"
     ]
    }
   ],
   "source": [
    "rnn_real = nn.RNN(input_size, hidden_size, batch_first=True)\n",
    "out_real, h_real = rnn_real(input, h_prev)\n",
    "print(out_real.shape, h_real.shape)\n",
    "print(\"---Parameters---\")\n",
    "for k, v in rnn_real.named_parameters():\n",
    "    print(k, v.shape)\n",
    "print(\"---Output---\")\n",
    "print(out_real)\n",
    "print(h_real)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "h_t = \\tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class RnnCell(nn.Module):\n",
    "    # One step of the recurrence: h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)\n",
    "    def __init__(self, input_size, hidden_size, activ=nn.Tanh()):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.activ = activ\n",
    "        # Same shapes as nn.RNN's weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0\n",
    "        self.w_ih = nn.Parameter(torch.randn(hidden_size, input_size))\n",
    "        self.w_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))\n",
    "        self.b_ih = nn.Parameter(torch.randn(hidden_size))\n",
    "        self.b_hh = nn.Parameter(torch.randn(hidden_size))\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        # b_hh keeps its torch.randn initialization\n",
    "        nn.init.xavier_normal_(self.w_ih)\n",
    "        nn.init.xavier_normal_(self.w_hh)\n",
    "        nn.init.constant_(self.b_ih, 0.1)\n",
    "\n",
    "    def forward(self, x_t: Tensor, h_prev: Tensor) -> Tensor:\n",
    "        # x_t: (batch, input_size); h_prev: (batch, hidden_size) or (1, batch, hidden_size)\n",
    "        h = self.activ(x_t @ self.w_ih.T + self.b_ih + h_prev @ self.w_hh.T + self.b_hh)\n",
    "        return h\n",
    "\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    # Single-layer, unidirectional RNN that mirrors nn.RNN(..., batch_first=True)\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnn = RnnCell(input_size, hidden_size, activ=nn.Tanh())\n",
    "        # Optional output projection, left disabled:\n",
    "        # self.w_y = nn.Parameter(torch.randn(hidden_size, output_size))\n",
    "        # self.b_y = nn.Parameter(torch.randn(output_size))\n",
    "        self.leaky_relu = nn.LeakyReLU()\n",
    "\n",
    "    def forward(self, x: Tensor, h_prev=None):\n",
    "        # x: (batch, seq_len, input_size); h_prev: (1, batch, hidden_size), as in nn.RNN\n",
    "        # Use the sizes of x itself rather than the notebook's global batch_size / seq_len\n",
    "        batch_size_, seq_len_, _ = x.shape\n",
    "        h_out = torch.zeros((batch_size_, seq_len_, self.hidden_size), device=x.device, dtype=x.dtype)\n",
    "        if h_prev is None:\n",
    "            h_prev = torch.zeros((1, batch_size_, self.hidden_size), device=x.device, dtype=x.dtype)\n",
    "        for i in range(seq_len_):\n",
    "            x_t = x[:, i, :]\n",
    "            # (batch, hidden) + (1, batch, hidden) broadcasts to (1, batch, hidden)\n",
    "            h_prev = self.rnn(x_t, h_prev)\n",
    "            # y = self.leaky_relu((h_prev @ self.w_y) + self.b_y)\n",
    "            h_out[:, i, :] = h_prev[0]\n",
    "        return h_out, h_prev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2, 3]) torch.Size([2, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "my_rnn = RNN(input_size, hidden_size)\n",
    "# for k, v in my_rnn.named_parameters():\n",
    "#     print(k, v.shape)\n",
    "out, h = my_rnn(input, h_prev)\n",
    "print(h.shape, out.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.7332, -0.1584, -0.4056],\n",
      "         [ 0.8495,  0.7383, -0.1428],\n",
      "         [ 0.9238,  0.2170,  0.5095]],\n",
      "\n",
      "        [[ 0.7086, -0.0690, -0.6058],\n",
      "         [ 0.8293,  0.9016, -0.0904],\n",
      "         [ 0.8779,  0.7847, -0.0404]]], grad_fn=<CopySlices>)\n",
      "tensor([[[ 0.9238,  0.2170,  0.5095],\n",
      "         [ 0.8779,  0.7847, -0.0404]]], grad_fn=<TanhBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# weight_ih_l0 torch.Size([3, 2])\n",
    "# weight_hh_l0 torch.Size([3, 3])\n",
    "# bias_ih_l0 torch.Size([3])\n",
    "# bias_hh_l0 torch.Size([3])\n",
    "\n",
    "# Reuse the parameters of nn.RNN so that both models compute the same function\n",
    "my_rnn = RNN(input_size, hidden_size)\n",
    "my_rnn.rnn.w_ih = rnn_real.weight_ih_l0\n",
    "my_rnn.rnn.w_hh = rnn_real.weight_hh_l0\n",
    "my_rnn.rnn.b_ih = rnn_real.bias_ih_l0\n",
    "my_rnn.rnn.b_hh = rnn_real.bias_hh_l0\n",
    "# for k, v in my_rnn.named_parameters():\n",
    "#     print(k, v.shape)\n",
    "out, h = my_rnn(input, h_prev)\n",
    "print(out)\n",
    "print(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.allclose(out_real, out)\n",
    "assert torch.allclose(h_real, h)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bidirectional RNN"
   ]
  },
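  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bidirectional shapes seen in the API section, `(1, 2, 6)` for `output` and `(2, 1, 3)` for `h_n`, follow from running two independent recurrences, one per direction, and concatenating their hidden states. The equations below are a sketch: the arrow notation, $T$ (sequence length), $B$ (batch size), and $H$ (hidden size) are assumptions introduced here, and each direction has its own weights and biases.\n",
    "\n",
    "$$\n",
    "\\overrightarrow{h}_t = \\tanh(x_t \\overrightarrow{W}_{ih}^T + \\overrightarrow{b}_{ih} + \\overrightarrow{h}_{t-1} \\overrightarrow{W}_{hh}^T + \\overrightarrow{b}_{hh}), \\quad t = 1, \\dots, T\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\overleftarrow{h}_t = \\tanh(x_t \\overleftarrow{W}_{ih}^T + \\overleftarrow{b}_{ih} + \\overleftarrow{h}_{t+1} \\overleftarrow{W}_{hh}^T + \\overleftarrow{b}_{hh}), \\quad t = T, \\dots, 1\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{output}_t = [\\overrightarrow{h}_t \\,;\\, \\overleftarrow{h}_t] \\in \\mathbb{R}^{B \\times 2H}, \\qquad h_n = (\\overrightarrow{h}_T,\\ \\overleftarrow{h}_1)\n",
    "$$\n",
    "\n",
    "The last dimension of `output` is therefore $2H$, and `h_n` stacks the final forward state (last time step) with the final backward state (first time step), giving shape $(2, B, H)$."
   ]
  },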
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 6]) torch.Size([2, 2, 3])\n",
      "---Parameters---\n",
      "weight_ih_l0 torch.Size([3, 2])\n",
      "weight_hh_l0 torch.Size([3, 3])\n",
      "bias_ih_l0 torch.Size([3])\n",
      "bias_hh_l0 torch.Size([3])\n",
      "weight_ih_l0_reverse torch.Size([3, 2])\n",
      "weight_hh_l0_reverse torch.Size([3, 3])\n",
      "bias_ih_l0_reverse torch.Size([3])\n",
      "bias_hh_l0_reverse torch.Size([3])\n",
      "---Output---\n",
      "tensor([[[ 0.2401, -0.2251,  0.4411,  0.7865, -0.2342,  0.4105],\n",
      "         [ 0.2998,  0.5483, -0.4422, -0.0885,  0.0915, -0.1140],\n",
      "         [-0.3084,  0.5297, -0.0544,  0.6101, -0.2395,  0.4977]],\n",
      "\n",
      "        [[ 0.4931, -0.4507,  0.4582,  0.8968, -0.2503,  0.4697],\n",
      "         [ 0.4876,  0.7293, -0.7174, -0.5585, -0.2563,  0.0718],\n",
      "         [ 0.6657,  0.0835, -0.4270, -0.1043, -0.2165, -0.0148]]],\n",
      "       grad_fn=<TransposeBackward1>)\n",
      "tensor([[[-0.3084,  0.5297, -0.0544],\n",
      "         [ 0.6657,  0.0835, -0.4270]],\n",
      "\n",
      "        [[ 0.7865, -0.2342,  0.4105],\n",
      "         [ 0.8968, -0.2503,  0.4697]]], grad_fn=<StackBackward0>)\n"
     ]
    }
   ],
   "source": [
    "h_prev = torch.zeros((2, batch_size, hidden_size))\n",
    "\n",
    "rnn_real = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)\n",
    "out_real, h_real = rnn_real(input, h_prev)\n",
    "print(out_real.shape, h_real.shape)\n",
    "print(\"---Parameters---\")\n",
    "for k, v in rnn_real.named_parameters():\n",
    "    print(k, v.shape)\n",
    "print(\"---Output---\")\n",
    "print(out_real)\n",
    "print(h_real)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
